import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
import torch
import numpy as np
import torch.utils.tensorboard as tb

from isaacgymenvs.ddim.runners.diffusion import Diffusion


# Print tensors in plain decimal notation rather than scientific notation.
torch.set_printoptions(sci_mode=False)


# Options copied verbatim (same name) from the parsed CLI arguments onto
# ``config.invdyn``.  ``--invdyn_obs_type`` is handled separately because it is
# stored under the shorter key ``obs_type``.
_INVDYN_CLI_KEYS = (
    "future_ref_dim",
    "history_obs_dim",
    "w_obj_state_history",
    "model_arch",
    "res_blocks",
    "history_length",
    "action_type",
    "mask_out_obj_motion",
    "pred_extrin",
    "use_relative_target",
    "future_length",
    "normalize_input",
    "normalize_output",
    "use_obj_motion_norm_command",
    "obj_state_predictor",
    "future_act_dim",
    "token_mimicking",
    "masked_cond",
    "obj_motion_format",
    "load_experience_via_mode",
    "rnd_masked_cond",
    "invdyn_masked_obj_motion_cond",
    "num_workers",
    "encode_data",
    "train_tokenizer",
    "resume_tokenizer",
    "mixed_sim_real_experiences",
    "real_experiences_traj_idx_to_file_name",
    "finger_idx",
    "joint_idx",
    "optimize_via_fingertip_pos",
    "wm_history_length",
    "train_obj_motion_pred_model",
    "hist_context_length",
    "add_obs_noise_scale",
    "add_action_noise_scale",
    "hist_context_finger_idx",
    "train_q_value_model",
    "q_value_model_expectile_regression",
    "wm_pred_joint_idx",
    "train_finger_pos_tracking_model",
    "finger_pos_tracking_target_finger_idx",
    "train_inverse_dynamics_model",
    "train_value_network",
    "add_noise_onto_hist_obs",
    "hist_obs_nosie_scale",  # (sic) spelling is part of the CLI/config interface
    "train_residual_wm",
    "prev_wm_ckpt",
    "w_hand_root_ornt",
    "add_nearing_neighbour",
    "add_nearing_finger",
    "wm_as_invdyn_prediction",
    "stack_wm_history",
    "multi_joint_single_wm",
    "load_pretrained_wm_ckpt",
    "delta_action_scale",
    "multi_finger_single_wm",
    "single_hand_wm",
    "use_sepearate_test_data",  # (sic)
    "seperate_test_data_fn",  # (sic)
    "multi_joint_single_shared_wm",
    "fullhand_wobjstate_wm",
    "pred_nearing_joint",
)

# Options copied verbatim (same name) onto ``config.training``.
_TRAINING_CLI_KEYS = ("batch_size", "n_epochs", "logging_step_interval")


def _build_arg_parser():
    """Create the argument parser declaring every option this script accepts."""
    parser = argparse.ArgumentParser(description=globals()["__doc__"])
    add = parser.add_argument

    # ---- experiment bookkeeping / generic diffusion options ----
    add("--config", type=str, required=True, help="Path to the config file")
    add("--seed", type=int, default=1234, help="Random seed")
    add("--exp", type=str, default="exp", help="Path for saving running related data.")
    add("--model_type", type=str, default="diffusion")
    add(
        "--doc",
        type=str,
        required=True,
        help="A string for documentation purpose. "
        "Will be the name of the log folder.",
    )
    add("--comment", type=str, default="", help="A string for experiment comment")
    add(
        "--verbose",
        type=str,
        default="info",
        help="Verbose level: info | debug | warning | critical",
    )
    add("--test", action="store_true", help="Whether to test the model")
    add(
        "--sample",
        action="store_true",
        help="Whether to produce samples from the model",
    )
    add("--fid", action="store_true")
    add("--interpolation", action="store_true")
    add("--resume_training", action="store_true", help="Whether to resume training")
    add(
        "-i",
        "--image_folder",
        type=str,
        default="images",
        help="The folder name of samples",
    )
    add(
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    add("--use_pretrained", action="store_true")
    add(
        "--sample_type",
        type=str,
        default="generalized",
        help="sampling approach (generalized or ddpm_noisy)",
    )
    add(
        "--skip_type",
        type=str,
        default="uniform",
        help="skip according to (uniform or quadratic)",
    )
    add("--timesteps", type=int, default=1000, help="number of steps involved")
    add(
        "--eta",
        type=float,
        default=0.0,
        help="eta used to control the variances of sigma",
    )
    add("--sequence", action="store_true")

    # ---- data / inverse-dynamics model options ----
    add("--data_type", type=str, default="tracking")
    add("--data_path", type=str, default="../../assets/traj_idx_to_demo_dict_9.npy")
    add("--future_type", type=str, default="hand_motion")
    add("--future_ref_dim", type=int, default=16)
    add("--invdyn_obs_type", type=str, default="hand_qpos_qtars")
    add("--history_obs_dim", type=int, default=32)
    add("--w_obj_state_history", action="store_true")
    add("--use_stochastic_dataset", action="store_true")
    add("--model_arch", type=str, default="resmlp")
    add("--res_blocks", type=int, default=2)
    add("--history_length", type=int, default=4)
    add("--future_length", type=int, default=2)
    add("--action_type", type=str, default="absolute")
    add("--mask_out_obj_motion", action="store_true")
    add("--batch_size", type=int, default=32)
    add("--pred_extrin", action="store_true")
    add("--use_relative_target", action="store_true")
    add("--normalize_input", action="store_true")
    add("--normalize_output", action="store_true")
    add("--use_obj_motion_norm_command", action="store_true")
    add("--obj_state_predictor", action="store_true")
    add("--future_act_dim", type=int, default=16)
    add("--token_mimicking", action="store_true")
    add("--masked_cond", action="store_true")
    add("--obj_motion_format", type=str, default="ori_motion")
    add("--load_experience_via_mode", action="store_true")
    add("--rnd_masked_cond", action="store_true")
    add("--invdyn_masked_obj_motion_cond", action="store_true")
    add("--num_workers", type=int, default=32)
    add("--encode_data", action="store_true")
    add("--train_tokenizer", action="store_true")
    add("--train_ddp", action="store_true")
    add("--resume_tokenizer", type=str, default="")
    add("--mixed_sim_real_experiences", action="store_true")
    add("--real_experiences_traj_idx_to_file_name", type=str, default="")

    # ---- world-model / delta-action / value-model options ----
    add("--world_model", action="store_true")
    add("--train_delta_action_model", action="store_true")
    add("--sim_world_model_path", type=str, default="")
    add("--real_world_model_path", type=str, default="")
    add("--delta_action_model_path", type=str, default="")
    add("--finger_idx", type=int, default=-1)
    add("--joint_idx", type=int, default=-1)
    add("--optimize_via_fingertip_pos", action="store_true")
    add("--wm_history_length", type=int, default=1)
    add("--train_obj_motion_pred_model", action="store_true")
    add("--hist_context_length", type=int, default=0)
    add("--add_obs_noise_scale", type=float, default=0.0, help="add_obs_noise_scale")
    add(
        "--add_action_noise_scale",
        type=float,
        default=0.0,
        help="add_action_noise_scale",
    )
    add("--hist_context_finger_idx", type=int, default=-1)
    add("--train_q_value_model", action="store_true")
    add("--q_value_model_expectile_regression", action="store_true")
    add("--wm_pred_joint_idx", type=int, default=-1)
    add("--train_finger_pos_tracking_model", action="store_true")
    add("--finger_pos_tracking_target_finger_idx", type=int, default=0)
    add("--train_inverse_dynamics_model", action="store_true")
    add("--train_value_network", action="store_true")
    add("--add_noise_onto_hist_obs", action="store_true")
    add("--hist_obs_nosie_scale", type=float, default=0.02)
    add("--train_residual_wm", action="store_true")
    add("--prev_wm_ckpt", type=str, default="")
    add("--test_world_model_w_compensator", action="store_true")
    add("--train_world_model_via_invdyn", action="store_true")
    add("--n_epochs", type=int, default=10000)
    add("--eval_world_model_ckpt", type=str, default="")
    add("--w_hand_root_ornt", action="store_true")
    add("--add_nearing_neighbour", action="store_true")
    add("--add_nearing_finger", action="store_true")
    add("--wm_as_invdyn_prediction", action="store_true")
    add("--stack_wm_history", action="store_true")
    add("--multi_joint_single_wm", action="store_true")
    add("--load_pretrained_wm_ckpt", type=str, default="")
    add("--delta_action_scale", type=float, default=0.041666666666666664)
    add("--multi_finger_single_wm", action="store_true")
    add("--eval_use_test_set", action="store_true")
    add("--single_hand_wm", action="store_true")
    add("--use_sepearate_test_data", action="store_true")
    add("--seperate_test_data_fn", type=str, default="")
    add("--logging_step_interval", type=int, default=20000)
    add("--multi_joint_single_shared_wm", action="store_true")
    add("--finetune_policy_w_world_model", action="store_true")
    add("--fullhand_wobjstate_wm", action="store_true")
    add("--pred_nearing_joint", action="store_true")

    return parser


def _apply_cli_overrides(args, config):
    """Overwrite ``config.invdyn`` / ``config.training`` entries with CLI values."""
    invdyn = config.invdyn
    # The only option whose config key differs from its CLI name.
    invdyn.obs_type = args.invdyn_obs_type
    for key in _INVDYN_CLI_KEYS:
        setattr(invdyn, key, getattr(args, key))

    training = config.training
    for key in _TRAINING_CLI_KEYS:
        setattr(training, key, getattr(args, key))


def _confirm_overwrite(prompt, assume_yes):
    """Return True when an existing folder may be overwritten.

    With ``assume_yes`` (the ``--ni`` flag) no question is asked; otherwise the
    user is prompted and only an answer of "Y" (case-insensitive) confirms.
    """
    if assume_yes:
        return True
    return input(prompt).upper() == "Y"


def _reset_training_dirs(args, config, tb_path):
    """Create a fresh log folder for a new training run and dump the config.

    An existing log folder is removed (together with the tensorboard folder)
    after confirmation; if the user declines, the process exits with status 0.
    """
    if os.path.exists(args.log_path):
        if not _confirm_overwrite("Folder already exists. Overwrite? (Y/N)", args.ni):
            print("Folder exists. Program halted.")
            sys.exit(0)

        shutil.rmtree(args.log_path)
        # The tensorboard folder may be missing even though the log folder
        # exists, so only remove it when it is actually there.
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

    os.makedirs(args.log_path)

    with open(os.path.join(args.log_path, "config.yml"), "w") as f:
        yaml.dump(config, f, default_flow_style=False)


def _setup_root_logger(verbose, log_file=None):
    """Attach a stream handler (and optionally a file handler) to the root logger.

    Raises:
        ValueError: if ``verbose`` does not name an integer logging level.
    """
    level = getattr(logging, verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError("level {} not supported".format(verbose))

    handlers = [logging.StreamHandler()]
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file))

    formatter = logging.Formatter(
        "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
    )
    root_logger = logging.getLogger()
    for handler in handlers:
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)
    root_logger.setLevel(level)


def _prepare_sample_dir(args):
    """Resolve ``args.image_folder`` under ``<exp>/image_samples`` and create it.

    An existing folder is kept as-is for ``--fid`` / ``--interpolation`` runs;
    otherwise it is wiped after confirmation, or the process exits with
    status 0 if the user declines.
    """
    os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)
    args.image_folder = os.path.join(args.exp, "image_samples", args.image_folder)

    if not os.path.exists(args.image_folder):
        os.makedirs(args.image_folder)
        return
    if args.fid or args.interpolation:
        return

    prompt = f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
    if not _confirm_overwrite(prompt, args.ni):
        print("Output image folder exists. Program halted.")
        sys.exit(0)

    shutil.rmtree(args.image_folder)
    os.makedirs(args.image_folder)


def parse_args_and_config():
    """Parse the command line, load the YAML config and prepare the run.

    The config file is read from ``configs/<--config>`` relative to the current
    working directory, converted to nested namespaces, and then overridden by
    the command-line options.  Depending on the mode the function also

    * (training) creates ``<exp>/logs/<doc>``, writes ``config.yml`` into it,
      attaches a tensorboard ``SummaryWriter`` as ``new_config.tb_logger`` and
      logs to both the console and ``stdout.txt``;
    * (``--test`` / ``--sample``) logs to the console only and, for
      ``--sample``, prepares the image output folder.

    Finally the torch / numpy random seeds are set and cuDNN benchmarking is
    enabled.

    Returns:
        tuple: ``(args, new_config)`` - the parsed ``argparse.Namespace``
        (augmented with ``log_path``) and the config namespace.

    Raises:
        ValueError: if ``--verbose`` is not a valid logging level name.
        KeyError: if ``--train_ddp`` is given but ``LOCAL_RANK`` is not set in
            the environment.
        SystemExit: if the user declines to overwrite an existing folder.
    """
    args = _build_arg_parser().parse_args()
    args.log_path = os.path.join(args.exp, "logs", args.doc)

    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)
    _apply_cli_overrides(args, new_config)

    # Without DDP there is a single process; with DDP only local rank 0 is
    # allowed to touch the log directories.
    is_main_process = True
    if args.train_ddp:
        cuda_device_idx = int(os.environ["LOCAL_RANK"])
        print(f"cuda_device_idx: {cuda_device_idx}")
        is_main_process = cuda_device_idx == 0

    tb_path = os.path.join(args.exp, "tensorboard", args.doc)

    if not args.test and not args.sample:
        if not args.resume_training and is_main_process:
            _reset_training_dirs(args, new_config, tb_path)

        new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
        _setup_root_logger(args.verbose, os.path.join(args.log_path, "stdout.txt"))
    else:
        _setup_root_logger(args.verbose)
        if args.sample:
            _prepare_sample_dir(args)

    # Under DDP no device is stored on the config here.
    if not args.train_ddp:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        logging.info("Using device: {}".format(device))
        new_config.device = device

    new_config.data_type = args.data_type

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config


def dict2namespace(config):
    """Recursively turn a (nested) dict into an ``argparse.Namespace``.

    Every key becomes an attribute.  Values that are themselves dicts are
    converted recursively; all other values (including lists that contain
    dicts) are stored unchanged.
    """
    ns = argparse.Namespace()
    for name, entry in config.items():
        setattr(ns, name, dict2namespace(entry) if isinstance(entry, dict) else entry)
    return ns


def main():
    """Build the runner and execute the routine selected on the command line.

    The runner class is chosen from ``config.data.dataset`` ("controlseq" uses
    the control-sequence runner, anything else the default ``Diffusion``).
    With ``--model_type invdyn`` the sampling / training routine is picked from
    the mode flags; otherwise ``--sample``, ``--test`` or plain training is run.

    Any exception raised while building or running the runner is logged with
    its traceback instead of being propagated.

    Returns:
        int: process exit status - 0 on success, 1 if an exception was caught.
    """
    args, config = parse_args_and_config()
    logging.info("Writing log file to {}".format(args.log_path))
    logging.info("Exp instance id = {}".format(os.getpid()))
    logging.info("Exp comment = {}".format(args.comment))

    try:
        if config.data.dataset == "controlseq":
            from runners.diffusion_controlseq import Diffusion as DiffusionControlSeq
            runner = DiffusionControlSeq(args, config)
        else:
            runner = Diffusion(args, config)

        if args.model_type == 'invdyn':
            if args.sample:
                if args.train_delta_action_model:
                    runner.sample_world_model_delta_actions()
                elif args.world_model:
                    runner.eval_world_model(args.eval_world_model_ckpt)
                else:
                    runner.sample_invdyn()
            else:
                # The first matching mode flag wins; order matters.
                if args.test_world_model_w_compensator:
                    runner.test_world_model_w_compensator()
                elif args.train_inverse_dynamics_model:
                    runner.train_inverse_dynamics_model()
                elif args.train_q_value_model:
                    runner.train_q_value_model()
                elif args.train_delta_action_model:
                    if args.train_ddp:
                        if args.finetune_policy_w_world_model:
                            from runners.diffusion_controlseq import finetune_policy_w_world_model_ddp
                            finetune_policy_w_world_model_ddp(runner)
                        else:
                            from runners.diffusion_controlseq import train_world_model_delta_actions_ddp
                            train_world_model_delta_actions_ddp(runner)
                    else:
                        runner.train_world_model_delta_actions()
                elif args.world_model:
                    if args.train_ddp:
                        from runners.diffusion_controlseq import train_world_model_ddp
                        train_world_model_ddp(runner)
                    else:
                        runner.train_world_model()
                elif args.token_mimicking:
                    runner.train_invdyn_tokenmimicking()
                else:
                    if args.train_ddp:
                        # The DDP entry point is called directly in this
                        # process; no worker processes are spawned here
                        # (presumably an external launcher starts one process
                        # per rank - TODO confirm).
                        from runners.diffusion_controlseq import train_invdyn_ddp
                        train_invdyn_ddp(runner)
                    else:
                        runner.train_invdyn()
        elif args.sample:
            runner.sample()
        elif args.test:
            runner.test()
        else:
            runner.train()
    except Exception:
        logging.error(traceback.format_exc())
        # Report the failure to the caller / job scheduler instead of exiting
        # with a success status after a crash.
        return 1

    return 0


# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
